import torch
import tqdm
from logs import log_model_performance, log_decoder_cosine_sim_quantiles
import multiprocessing as mp
from queue import Empty
import wandb
import json
import os
import json
from pathlib import Path

def save_checkpoint_mp(sae, cfg, step):
    """Write the SAE weights and a JSON copy of its config to a checkpoint folder.

    The folder is ``checkpoints/{cfg['name']}_{step}`` (relative to the current
    working directory) and is created when it does not exist yet.

    Args:
        sae: Object exposing ``state_dict()``; the result is handed to
            ``torch.save``.
        cfg: Mapping of configuration entries; must contain ``"name"``.
        step: Training step, used in the folder name and the printed message.

    Returns:
        Tuple ``(save_dir, sae_path, config_path)`` of path strings.
    """
    save_dir = f"checkpoints/{cfg['name']}_{step}"
    sae_path = os.path.join(save_dir, "sae.pt")
    config_path = os.path.join(save_dir, "config.json")
    os.makedirs(save_dir, exist_ok=True)

    # Weights first.
    torch.save(sae.state_dict(), sae_path)

    # Values JSON can represent natively are kept as-is; everything else
    # (dtypes, classes, containers, ...) is stored as its string form.
    json_native = (int, float, str, bool, type(None))
    serializable = {
        name: entry if isinstance(entry, json_native) else str(entry)
        for name, entry in cfg.items()
    }
    with open(config_path, "w") as handle:
        json.dump(serializable, handle, indent=4)

    print(f"Model and config saved at step {step} in {save_dir}")
    return save_dir, sae_path, config_path


        
def _wandb_logging_process(config, log_queue, entity, project):
    """Own one wandb run in a child process and log whatever arrives on ``log_queue``.

    Defined at module level (not nested) so it can be pickled as a
    ``multiprocessing.Process`` target under the ``spawn``/``forkserver``
    start methods as well as ``fork``.

    Queue protocol:
        * the string ``"DONE"`` ends the loop;
        * a dict with a truthy ``"checkpoint"`` entry uploads ``sae.pt`` and
          ``config.json`` from ``log["save_dir"]`` as a model artifact named
          ``{config['name']}_{log['step']}``;
        * anything else is forwarded to ``run.log``.
    """
    run = wandb.init(
        entity=entity,
        project=project,
        config=config,
        name=config["name"],
    )
    try:
        while True:
            # Only the queue read can raise Empty; keep the try that narrow so
            # real logging errors are not confused with an idle queue.
            try:
                log = log_queue.get(timeout=1)
            except Empty:
                continue

            if log == "DONE":
                break

            if isinstance(log, dict) and log.get("checkpoint"):
                artifact = wandb.Artifact(
                    name=f"{config['name']}_{log['step']}",
                    type="model",
                    description=f"Model checkpoint at step {log['step']}",
                )
                save_dir = log["save_dir"]
                artifact.add_file(os.path.join(save_dir, "sae.pt"))
                artifact.add_file(os.path.join(save_dir, "config.json"))
                run.log_artifact(artifact)
            else:
                run.log(log)
    finally:
        # Close the run even if logging failed part-way through.
        run.finish()


def _explained_variance_ratio(batch, reconstruction):
    """Return the share of ``batch``'s summed per-column variance that ``reconstruction`` captures."""
    # Pure metric: no gradient graph is needed for it.
    with torch.no_grad():
        total_variance = torch.var(batch, dim=0).sum().item()
        unexplained_variance = torch.var(batch - reconstruction, dim=0).sum().item()
    return (total_variance - unexplained_variance) / total_variance


def train_meta_saes(saes, meta_saes, cfgs):
    """Train each meta-SAE on rows sampled from its base SAE's decoder matrix.

    The ``saes``, ``meta_saes`` and ``cfgs`` sequences are processed in
    lockstep, one (sae, meta_sae, cfg) triple after another. Every config gets
    its own wandb child process, fed through a queue. For each triple, the
    explained-variance history is written to
    ``results/{cfg['model_name']}/metrics.json``.

    Args:
        saes: Sequence of objects exposing a ``W_dec`` tensor; its rows are
            the training samples.
        meta_saes: Sequence of modules. Calling one on a batch must return a
            dict with ``"loss"``, ``"sae_out"``, ``"l0_norm"``, ``"l2_loss"``
            and ``"l1_loss"``. Each must also provide ``parameters()``,
            ``config["act_size"]`` and
            ``make_decoder_weights_and_grad_unit_norm()``.
        cfgs: Sequence of dicts with the keys ``num_tokens``, ``batch_size``,
            ``lr``, ``beta1``, ``beta2``, ``name``, ``wandb_project``,
            ``perf_log_freq_base_metrics``, ``max_grad_norm``, ``model_name``
            and optionally ``wandb_entity``. The number of steps is taken
            from ``cfgs[0]`` only.

    Returns:
        None.
    """
    num_batches = int(cfgs[0]["num_tokens"] // cfgs[0]["batch_size"])
    print(f"Number of batches: {num_batches}")

    optimizers = [
        torch.optim.Adam(meta_sae.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"]))
        for meta_sae, cfg in zip(meta_saes, cfgs)
    ]

    wandb_processes = []
    log_queues = []

    try:
        # One logger process + queue per config; index i matches cfgs[i].
        for cfg in cfgs:
            log_queue = mp.Queue()
            wandb_process = mp.Process(
                target=_wandb_logging_process,
                args=(cfg, log_queue, cfg.get("wandb_entity", ""), cfg["wandb_project"]),
            )
            wandb_process.start()
            # Register only after a successful start so shutdown never waits
            # on a process that does not exist.
            log_queues.append(log_queue)
            wandb_processes.append(wandb_process)

        for idx, (sae, meta_sae, cfg, optimizer) in enumerate(zip(saes, meta_saes, cfgs, optimizers)):
            metrics = {
                'explained_variance': [],
                'steps': []
            }

            pbar = tqdm.trange(num_batches)
            W_dec = sae.W_dec.detach()

            for step in pbar:
                # Reseeding with the step number makes the sampled row indices
                # a function of the step alone.
                torch.manual_seed(step)
                random_indices = torch.randint(0, W_dec.size(0), (cfg["batch_size"], ))
                batch = W_dec[random_indices]
                sae_output = meta_sae(batch)
                loss = sae_output["loss"]

                if step % cfg["perf_log_freq_base_metrics"] == 0:
                    # Forward only plain numbers and 0-dim tensors (as Python
                    # scalars) to wandb; larger tensors are dropped.
                    log_dict = {
                        k: v.item() if isinstance(v, torch.Tensor) and v.dim() == 0 else v
                        for k, v in sae_output.items()
                        if isinstance(v, (int, float))
                        or (isinstance(v, torch.Tensor) and v.dim() == 0)
                    }

                    reconstruction = sae_output["sae_out"].reshape(-1, meta_sae.config["act_size"])
                    explained_variance_ratio = _explained_variance_ratio(batch, reconstruction)
                    metrics['steps'].append(step)
                    metrics['explained_variance'].append(explained_variance_ratio)
                    log_dict['explained_variance_ratio'] = explained_variance_ratio
                    log_queues[idx].put(log_dict)

                pbar.set_postfix({
                    f"Loss_{idx}": f"{loss.item():.4f}",
                    f"L0_{idx}": f"{sae_output['l0_norm']:.4f}",
                    f"L2_{idx}": f"{sae_output['l2_loss']:.4f}",
                    f"L1_{idx}": f"{sae_output['l1_loss']:.4f}",
                })

                loss.backward()
                torch.nn.utils.clip_grad_norm_(meta_sae.parameters(), cfg["max_grad_norm"])
                meta_sae.make_decoder_weights_and_grad_unit_norm()
                optimizer.step()
                optimizer.zero_grad()

            # NOTE(review): the path depends only on cfg['model_name'], so
            # configs sharing a model_name overwrite each other's metrics.json
            # -- confirm this is intended.
            output_dir = Path(f"results/{cfg['model_name']}")
            output_dir.mkdir(parents=True, exist_ok=True)

            with open(output_dir/'metrics.json', 'w') as f:
                json.dump(metrics, f, indent=2)

            print(f"Saved metrics for {cfg['model_name']} to {output_dir/'metrics.json'}")
    finally:
        # Always release the logger processes, including when training raised;
        # otherwise they poll their queues forever and are never joined.
        for queue in log_queues:
            queue.put("DONE")
        for process in wandb_processes:
            process.join()

  